from voxcell.cell_collection import CellCollection
import numpy as np
import seaborn as sns
import pandas as pd
from scipy.stats import linregress
import matplotlib.pyplot as plt
from morph_tool.neuron_surface import get_NEURON_surface
from neurom import load_morphology
from neurom import get
from neurom import NeuriteType
import yaml
from pathlib import Path
from morph_tool.morphdb import MorphDB

from emodel_generalisation.exemplars import generate_exemplars

# Morphological type analysed by this script.
MTYPE = "Rt_RC"

# Only the basal dendrites are measured below.
_BASAL = NeuriteType.basal_dendrite


def _circuit_morphologies(circuit_path, mtype):
    """Return the unique morphology names assigned to ``mtype`` cells in the circuit nodes file."""
    nodes = CellCollection.load(
        circuit_path / "networks" / "nodes" / "thalamus_neurons" / "nodes.h5"
    ).as_dataframe()
    return nodes.loc[nodes.mtype == mtype, "morphology"].unique().tolist()


def _used_release_morphologies(release_path, circuit_morphs, mtype):
    """Return the release morphologies of ``mtype`` that are used by the circuit.

    A release morphology counts as used when its name is a substring of at least one
    circuit morphology name. NOTE(review): this presumes circuit names embed the release
    name (e.g. clones); a short name could also match unrelated circuit names — confirm.

    Returns a DataFrame with a fresh RangeIndex and its ``path`` column converted to ``str``.
    """
    release = MorphDB.from_neurondb(
        release_path + "/neuronDB.xml", morphology_folder=release_path
    ).df
    release = release[release.mtype == mtype]
    used = {
        name
        for name in release["name"]
        if any(name in circuit_name for circuit_name in circuit_morphs)
    }
    print("# used repaired:", len(used))
    release = release[release["name"].isin(used)].reset_index(drop=True)
    release["path"] = release["path"].astype(str)
    return release


def _compute_morphometrics(paths):
    """Measure the soma surface and basal-dendrite features of each morphology file.

    Returns a DataFrame with one row per path and columns ``soma_area`` (NEURON soma
    surface), ``total_area`` (total basal area), ``n_neurites`` (number of basal
    dendrites) and ``max_branch_order`` (largest basal section branch order).
    """
    rows = []
    for path in paths:
        morph = load_morphology(path)
        rows.append(
            {
                "soma_area": get_NEURON_surface(path),
                "total_area": get("total_area", morph, neurite_type=_BASAL),
                "n_neurites": get("number_of_neurites", morph, neurite_type=_BASAL),
                # default=0: a morphology without basal dendrites must not crash max().
                "max_branch_order": max(
                    get("section_branch_orders", morph, neurite_type=_BASAL), default=0
                ),
            }
        )
    return pd.DataFrame(
        rows, columns=["soma_area", "total_area", "n_neurites", "max_branch_order"]
    )


def _plot_basal_vs_soma(soma_area, basal_area, filename):
    """Scatter basal area against soma area, overlay a least-squares line and save it."""
    soma_area = np.asarray(soma_area)
    basal_area = np.asarray(basal_area)
    fit = linregress(soma_area, basal_area)
    print(fit)

    fig = plt.figure(figsize=(3, 3))
    plt.scatter(soma_area, basal_area)
    plt.plot(soma_area, fit.intercept + fit.slope * soma_area, "r")
    plt.xlabel("soma area")
    plt.ylabel("basal area")
    plt.tight_layout()
    plt.savefig(filename)
    plt.close(fig)


def _plot_morphometrics(morphometrics, filename):
    """Draw one strip + box plot per column of ``morphometrics``, side by side, and save it."""
    n_cols = len(morphometrics.columns)
    # squeeze=False keeps axs two-dimensional even for a single column.
    fig, axs = plt.subplots(1, n_cols, figsize=(2.0 * n_cols, 3), squeeze=False)
    for column, ax in zip(morphometrics.columns, axs[0]):
        sns.stripplot(data=morphometrics[column], ax=ax, color="k")
        sns.boxplot(
            data=morphometrics[column], ax=ax, showfliers=False, fill=False, color="k"
        )
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)


def main():
    """Select the circuit's release morphologies, build exemplars and plot basal morphometrics.

    Writes ``morphologies.csv``, ``exemplar_data.yaml``, ``basal_soma.pdf`` and
    ``morphometrics.pdf`` to the current working directory.
    """
    # NOTE: alternative circuit location:
    # "/gpfs/bbp.cscs.ch/project/proj55/iavarone/releases/circuits/O1/2019-11-19_sonata_Zenodo/"
    circuit_path = Path("../circuit/")
    circuit_morphs = _circuit_morphologies(circuit_path, MTYPE)
    print("# unique clones:", len(circuit_morphs))

    morphs = _used_release_morphologies("../morphology_release", circuit_morphs, MTYPE)
    morphs.to_csv("morphologies.csv")

    exemplar_data = generate_exemplars(morphs)
    with open("exemplar_data.yaml", "w", encoding="utf-8") as f:
        yaml.dump(exemplar_data, f)

    metrics = _compute_morphometrics(morphs["path"])
    _plot_basal_vs_soma(metrics["soma_area"], metrics["total_area"], "basal_soma.pdf")
    _plot_morphometrics(
        metrics[["total_area", "n_neurites", "max_branch_order"]], "morphometrics.pdf"
    )


if __name__ == "__main__":
    main()
